{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Summarization (PyTorch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers, Datasets, and Evaluate libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install datasets evaluate transformers[sentencepiece]\n",
    "!pip install accelerate\n",
    "# To run the training on TPU, you will need to uncomment the following line:\n",
    "# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n",
    "!apt install git-lfs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will need to set up git. Adapt your email and name in the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git config --global user.email \"you@example.com\"\n",
    "!git config --global user.name \"Your Name\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['review_id', 'product_id', 'reviewer_id', 'stars', 'review_body', 'review_title', 'language', 'product_category'],\n",
       "        num_rows: 200000\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['review_id', 'product_id', 'reviewer_id', 'stars', 'review_body', 'review_title', 'language', 'product_category'],\n",
       "        num_rows: 5000\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['review_id', 'product_id', 'reviewer_id', 'stars', 'review_body', 'review_title', 'language', 'product_category'],\n",
       "        num_rows: 5000\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "spanish_dataset = load_dataset(\"amazon_reviews_multi\", \"es\")\n",
    "english_dataset = load_dataset(\"amazon_reviews_multi\", \"en\")\n",
    "english_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'>> Title: Worked in front position, not rear'\n",
       "'>> Review: 3 stars because these are not rear brakes as stated in the item description. At least the mount adapter only worked on the front fork of the bike that I got it for.'\n",
       "\n",
       "'>> Title: meh'\n",
       "'>> Review: Does it’s job and it’s gorgeous but mine is falling apart, I had to basically put it together again with hot glue'\n",
       "\n",
       "'>> Title: Can\\'t beat these for the money'\n",
       "'>> Review: Bought this for handling miscellaneous aircraft parts and hanger \"stuff\" that I needed to organize; it really fit the bill. The unit arrived quickly, was well packaged and arrived intact (always a good sign). There are five wall mounts-- three on the top and two on the bottom. I wanted to mount it on the wall, so all I had to do was to remove the top two layers of plastic drawers, as well as the bottom corner drawers, place it when I wanted and mark it; I then used some of the new plastic screw in wall anchors (the 50 pound variety) and it easily mounted to the wall. Some have remarked that they wanted dividers for the drawers, and that they made those. Good idea. My application was that I needed something that I can see the contents at about eye level, so I wanted the fuller-sized drawers. I also like that these are the new plastic that doesn\\'t get brittle and split like my older plastic drawers did. I like the all-plastic construction. It\\'s heavy duty enough to hold metal parts, but being made of plastic it\\'s not as heavy as a metal frame, so you can easily mount it to the wall and still load it up with heavy stuff, or light stuff. No problem there. For the money, you can\\'t beat it. Best one of these I\\'ve bought to date-- and I\\'ve been using some version of these for over forty years.'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def show_samples(dataset, num_samples=3, seed=42):\n",
    "    sample = dataset[\"train\"].shuffle(seed=seed).select(range(num_samples))\n",
    "    for example in sample:\n",
    "        print(f\"\\n'>> Title: {example['review_title']}'\")\n",
    "        print(f\"'>> Review: {example['review_body']}'\")\n",
    "\n",
    "\n",
    "show_samples(english_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "home                      17679\n",
       "apparel                   15951\n",
       "wireless                  15717\n",
       "other                     13418\n",
       "beauty                    12091\n",
       "drugstore                 11730\n",
       "kitchen                   10382\n",
       "toy                        8745\n",
       "sports                     8277\n",
       "automotive                 7506\n",
       "lawn_and_garden            7327\n",
       "home_improvement           7136\n",
       "pet_products               7082\n",
       "digital_ebook_purchase     6749\n",
       "pc                         6401\n",
       "electronics                6186\n",
       "office_product             5521\n",
       "shoes                      5197\n",
       "grocery                    4730\n",
       "book                       3756\n",
       "Name: product_category, dtype: int64"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "english_dataset.set_format(\"pandas\")\n",
    "english_df = english_dataset[\"train\"][:]\n",
    "# Show the counts for the top 20 product categories\n",
    "english_df[\"product_category\"].value_counts()[:20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_books(example):\n",
    "    return (\n",
    "        example[\"product_category\"] == \"book\"\n",
    "        or example[\"product_category\"] == \"digital_ebook_purchase\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "english_dataset.reset_format()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'>> Title: I\\'m dissapointed.'\n",
       "'>> Review: I guess I had higher expectations for this book from the reviews. I really thought I\\'d at least like it. The plot idea was great. I loved Ash but, it just didnt go anywhere. Most of the book was about their radio show and talking to callers. I wanted the author to dig deeper so we could really get to know the characters. All we know about Grace is that she is attractive looking, Latino and is kind of a brat. I\\'m dissapointed.'\n",
       "\n",
       "'>> Title: Good art, good price, poor design'\n",
       "'>> Review: I had gotten the DC Vintage calendar the past two years, but it was on backorder forever this year and I saw they had shrunk the dimensions for no good reason. This one has good art choices but the design has the fold going through the picture, so it\\'s less aesthetically pleasing, especially if you want to keep a picture to hang. For the price, a good calendar'\n",
       "\n",
       "'>> Title: Helpful'\n",
       "'>> Review: Nearly all the tips useful and. I consider myself an intermediate to advanced user of OneNote. I would highly recommend.'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "spanish_books = spanish_dataset.filter(filter_books)\n",
    "english_books = english_dataset.filter(filter_books)\n",
    "show_samples(english_books)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'>> Title: Easy to follow!!!!'\n",
       "'>> Review: I loved The dash diet weight loss Solution. Never hungry. I would recommend this diet. Also the menus are well rounded. Try it. Has lots of the information need thanks.'\n",
       "\n",
       "'>> Title: PARCIALMENTE DAÑADO'\n",
       "'>> Review: Me llegó el día que tocaba, junto a otros libros que pedí, pero la caja llegó en mal estado lo cual dañó las esquinas de los libros porque venían sin protección (forro).'\n",
       "\n",
       "'>> Title: no lo he podido descargar'\n",
       "'>> Review: igual que el anterior'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import concatenate_datasets, DatasetDict\n",
    "\n",
    "books_dataset = DatasetDict()\n",
    "\n",
    "for split in english_books.keys():\n",
    "    books_dataset[split] = concatenate_datasets(\n",
    "        [english_books[split], spanish_books[split]]\n",
    "    )\n",
    "    books_dataset[split] = books_dataset[split].shuffle(seed=42)\n",
    "\n",
    "# Peek at a few examples\n",
    "show_samples(books_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "books_dataset = books_dataset.filter(lambda x: len(x[\"review_title\"].split()) > 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_checkpoint = \"google/mt5-small\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [336, 259, 28387, 11807, 287, 62893, 295, 12507, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = tokenizer(\"I loved reading the Hunger Games!\")\n",
    "inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['▁I', '▁', 'loved', '▁reading', '▁the', '▁Hung', 'er', '▁Games', '</s>']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.convert_ids_to_tokens(inputs.input_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_input_length = 512\n",
    "max_target_length = 30\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    model_inputs = tokenizer(\n",
    "        examples[\"review_body\"], max_length=max_input_length, truncation=True\n",
    "    )\n",
    "    # Set up the tokenizer for the targets\n",
    "    with tokenizer.as_target_tokenizer():\n",
    "        labels = tokenizer(\n",
    "            examples[\"review_title\"], max_length=max_target_length, truncation=True\n",
    "        )\n",
    "\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_datasets = books_dataset.map(preprocess_function, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generated_summary = \"I absolutely loved reading the Hunger Games\"\n",
    "reference_summary = \"I loved reading the Hunger Games\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install rouge_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "rouge_score = evaluate.load(\"rouge\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'rouge1': AggregateScore(low=Score(precision=0.86, recall=1.0, fmeasure=0.92), mid=Score(precision=0.86, recall=1.0, fmeasure=0.92), high=Score(precision=0.86, recall=1.0, fmeasure=0.92)),\n",
       " 'rouge2': AggregateScore(low=Score(precision=0.67, recall=0.8, fmeasure=0.73), mid=Score(precision=0.67, recall=0.8, fmeasure=0.73), high=Score(precision=0.67, recall=0.8, fmeasure=0.73)),\n",
       " 'rougeL': AggregateScore(low=Score(precision=0.86, recall=1.0, fmeasure=0.92), mid=Score(precision=0.86, recall=1.0, fmeasure=0.92), high=Score(precision=0.86, recall=1.0, fmeasure=0.92)),\n",
       " 'rougeLsum': AggregateScore(low=Score(precision=0.86, recall=1.0, fmeasure=0.92), mid=Score(precision=0.86, recall=1.0, fmeasure=0.92), high=Score(precision=0.86, recall=1.0, fmeasure=0.92))}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores = rouge_score.compute(\n",
    "    predictions=[generated_summary], references=[reference_summary]\n",
    ")\n",
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Score(precision=0.86, recall=1.0, fmeasure=0.92)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores[\"rouge1\"].mid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install nltk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nltk\n",
    "\n",
    "nltk.download(\"punkt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'I grew up reading Koontz, and years ago, I stopped,convinced i had \"outgrown\" him.'\n",
       "'Still,when a friend was looking for something suspenseful too read, I suggested Koontz.'\n",
       "'She found Strangers.'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from nltk.tokenize import sent_tokenize\n",
    "\n",
    "\n",
    "def three_sentence_summary(text):\n",
    "    return \"\\n\".join(sent_tokenize(text)[:3])\n",
    "\n",
    "\n",
    "print(three_sentence_summary(books_dataset[\"train\"][1][\"review_body\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_baseline(dataset, metric):\n",
    "    summaries = [three_sentence_summary(text) for text in dataset[\"review_body\"]]\n",
    "    return metric.compute(predictions=summaries, references=dataset[\"review_title\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'rouge1': 16.74, 'rouge2': 8.83, 'rougeL': 15.6, 'rougeLsum': 15.96}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "score = evaluate_baseline(books_dataset[\"validation\"], rouge_score)\n",
    "rouge_names = [\"rouge1\", \"rouge2\", \"rougeL\", \"rougeLsum\"]\n",
    "rouge_dict = dict((rn, round(score[rn].mid.fmeasure * 100, 2)) for rn in rouge_names)\n",
    "rouge_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForSeq2SeqLM\n",
    "\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Seq2SeqTrainingArguments\n",
    "\n",
    "batch_size = 8\n",
    "num_train_epochs = 8\n",
    "# Show the training loss at every epoch\n",
    "logging_steps = len(tokenized_datasets[\"train\"]) // batch_size\n",
    "model_name = model_checkpoint.split(\"/\")[-1]\n",
    "\n",
    "args = Seq2SeqTrainingArguments(\n",
    "    output_dir=f\"{model_name}-finetuned-amazon-en-es\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    learning_rate=5.6e-5,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    weight_decay=0.01,\n",
    "    save_total_limit=3,\n",
    "    num_train_epochs=num_train_epochs,\n",
    "    predict_with_generate=True,\n",
    "    logging_steps=logging_steps,\n",
    "    push_to_hub=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    # Decode the generated summaries into text\n",
    "    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)\n",
    "    # Replace -100 in the labels with the pad token ID, since -100 cannot be decoded\n",
    "    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
    "    # Decode the reference summaries into text\n",
    "    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
    "    # ROUGE expects a newline after each sentence\n",
    "    decoded_preds = [\"\\n\".join(sent_tokenize(pred.strip())) for pred in decoded_preds]\n",
    "    decoded_labels = [\"\\n\".join(sent_tokenize(label.strip())) for label in decoded_labels]\n",
    "    # Compute ROUGE scores\n",
    "    result = rouge_score.compute(\n",
    "        predictions=decoded_preds, references=decoded_labels, use_stemmer=True\n",
    "    )\n",
    "    # Extract the median scores\n",
    "    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}\n",
    "    return {k: round(v, 4) for k, v in result.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForSeq2Seq\n",
    "\n",
    "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_datasets = tokenized_datasets.remove_columns(\n",
    "    books_dataset[\"train\"].column_names\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],\n",
       "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'input_ids': tensor([[  1494,    259,   8622,    390,    259,    262,   2316,   3435,    955,\n",
       "            772,    281,    772,   1617,    263,    305,  14701,    260,   1385,\n",
       "           3031,    259,  24146,    332,   1037,    259,  43906,    305,    336,\n",
       "            260,      1,      0,      0,      0,      0,      0,      0],\n",
       "        [   259,  27531,  13483,    259,   7505,    260, 112240,  15192,    305,\n",
       "          53198,    276,    259,  74060,    263,    260,    459,  25640,    776,\n",
       "           2119,    336,    259,   2220,    259,  18896,    288,   4906,    288,\n",
       "           1037,   3931,    260,   7083, 101476,   1143,    260,      1]]), 'labels': tensor([[ 7483,   259,  2364, 15695,     1,  -100],\n",
       "        [  259, 27531, 13483,   259,  7505,     1]]), 'decoder_input_ids': tensor([[    0,  7483,   259,  2364, 15695,     1],\n",
       "        [    0,   259, 27531, 13483,   259,  7505]])}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "features = [tokenized_datasets[\"train\"][i] for i in range(2)]\n",
    "data_collator(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Seq2SeqTrainer\n",
    "\n",
    "trainer = Seq2SeqTrainer(\n",
    "    model,\n",
    "    args,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 3.028524398803711,\n",
       " 'eval_rouge1': 16.9728,\n",
       " 'eval_rouge2': 8.2969,\n",
       " 'eval_rougeL': 16.8366,\n",
       " 'eval_rougeLsum': 16.851,\n",
       " 'eval_gen_len': 10.1597,\n",
       " 'eval_runtime': 6.1054,\n",
       " 'eval_samples_per_second': 38.982,\n",
       " 'eval_steps_per_second': 4.914}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_datasets.set_format(\"torch\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "batch_size = 8\n",
    "train_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"train\"],\n",
    "    shuffle=True,\n",
    "    collate_fn=data_collator,\n",
    "    batch_size=batch_size,\n",
    ")\n",
    "eval_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"validation\"], collate_fn=data_collator, batch_size=batch_size\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import AdamW\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=2e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from accelerate import Accelerator\n",
    "\n",
    "accelerator = Accelerator()\n",
    "model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n",
    "    model, optimizer, train_dataloader, eval_dataloader\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import get_scheduler\n",
    "\n",
    "num_train_epochs = 10\n",
    "num_update_steps_per_epoch = len(train_dataloader)\n",
    "num_training_steps = num_train_epochs * num_update_steps_per_epoch\n",
    "\n",
    "lr_scheduler = get_scheduler(\n",
    "    \"linear\",\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=num_training_steps,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def postprocess_text(preds, labels):\n",
    "    preds = [pred.strip() for pred in preds]\n",
    "    labels = [label.strip() for label in labels]\n",
    "\n",
    "    # ROUGE expects a newline after each sentence\n",
    "    preds = [\"\\n\".join(nltk.sent_tokenize(pred)) for pred in preds]\n",
    "    labels = [\"\\n\".join(nltk.sent_tokenize(label)) for label in labels]\n",
    "\n",
    "    return preds, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'lewtun/mt5-finetuned-amazon-en-es-accelerate'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from huggingface_hub import get_full_repo_name\n",
    "\n",
    "model_name = \"mt5-finetuned-amazon-en-es-accelerate\"\n",
    "repo_name = get_full_repo_name(model_name)\n",
    "repo_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import Repository\n",
    "\n",
    "output_dir = \"results-mt5-finetuned-squad-accelerate\"\n",
    "repo = Repository(output_dir, clone_from=repo_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Epoch 0: {'rouge1': 5.6351, 'rouge2': 1.1625, 'rougeL': 5.4866, 'rougeLsum': 5.5005}\n",
       "Epoch 1: {'rouge1': 9.8646, 'rouge2': 3.4106, 'rougeL': 9.9439, 'rougeLsum': 9.9306}\n",
       "Epoch 2: {'rouge1': 11.0872, 'rouge2': 3.3273, 'rougeL': 11.0508, 'rougeLsum': 10.9468}\n",
       "Epoch 3: {'rouge1': 11.8587, 'rouge2': 4.8167, 'rougeL': 11.7986, 'rougeLsum': 11.7518}\n",
       "Epoch 4: {'rouge1': 12.9842, 'rouge2': 5.5887, 'rougeL': 12.7546, 'rougeLsum': 12.7029}\n",
       "Epoch 5: {'rouge1': 13.4628, 'rouge2': 6.4598, 'rougeL': 13.312, 'rougeLsum': 13.2913}\n",
       "Epoch 6: {'rouge1': 12.9131, 'rouge2': 5.8914, 'rougeL': 12.6896, 'rougeLsum': 12.5701}\n",
       "Epoch 7: {'rouge1': 13.3079, 'rouge2': 6.2994, 'rougeL': 13.1536, 'rougeLsum': 13.1194}\n",
       "Epoch 8: {'rouge1': 13.96, 'rouge2': 6.5998, 'rougeL': 13.9123, 'rougeLsum': 13.7744}\n",
       "Epoch 9: {'rouge1': 14.1192, 'rouge2': 7.0059, 'rougeL': 14.1172, 'rougeLsum': 13.9509}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "progress_bar = tqdm(range(num_training_steps))\n",
    "\n",
    "for epoch in range(num_train_epochs):\n",
    "    # Training\n",
    "    model.train()\n",
    "    for step, batch in enumerate(train_dataloader):\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        accelerator.backward(loss)\n",
    "\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "        progress_bar.update(1)\n",
    "\n",
    "    # Evaluation\n",
    "    model.eval()\n",
    "    for step, batch in enumerate(eval_dataloader):\n",
    "        with torch.no_grad():\n",
    "            generated_tokens = accelerator.unwrap_model(model).generate(\n",
    "                batch[\"input_ids\"],\n",
    "                attention_mask=batch[\"attention_mask\"],\n",
    "            )\n",
    "\n",
    "            generated_tokens = accelerator.pad_across_processes(\n",
    "                generated_tokens, dim=1, pad_index=tokenizer.pad_token_id\n",
    "            )\n",
    "\n",
    "            # If we did not pad to max length, we need to pad the labels too\n",
    "            labels = accelerator.pad_across_processes(\n",
    "                batch[\"labels\"], dim=1, pad_index=tokenizer.pad_token_id\n",
    "            )\n",
    "\n",
    "            generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()\n",
    "            labels = accelerator.gather(labels).cpu().numpy()\n",
    "\n",
    "            # Replace -100 in the labels with the pad token ID, since -100 cannot be decoded\n",
    "            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
    "            if isinstance(generated_tokens, tuple):\n",
    "                generated_tokens = generated_tokens[0]\n",
    "            decoded_preds = tokenizer.batch_decode(\n",
    "                generated_tokens, skip_special_tokens=True\n",
    "            )\n",
    "            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
    "\n",
    "            decoded_preds, decoded_labels = postprocess_text(\n",
    "                decoded_preds, decoded_labels\n",
    "            )\n",
    "\n",
    "            rouge_score.add_batch(predictions=decoded_preds, references=decoded_labels)\n",
    "\n",
    "    # Compute metrics\n",
    "    result = rouge_score.compute()\n",
    "    # Extract the median ROUGE scores\n",
    "    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}\n",
    "    result = {k: round(v, 4) for k, v in result.items()}\n",
    "    print(f\"Epoch {epoch}:\", result)\n",
    "\n",
    "    # Save and upload\n",
    "    accelerator.wait_for_everyone()\n",
    "    unwrapped_model = accelerator.unwrap_model(model)\n",
    "    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)\n",
    "    if accelerator.is_main_process:\n",
    "        tokenizer.save_pretrained(output_dir)\n",
    "        repo.push_to_hub(\n",
    "            commit_message=f\"Training in progress epoch {epoch}\", blocking=False\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "hub_model_id = \"huggingface-course/mt5-small-finetuned-amazon-en-es\"\n",
    "summarizer = pipeline(\"summarization\", model=hub_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_summary(idx):\n",
    "    review = books_dataset[\"test\"][idx][\"review_body\"]\n",
    "    title = books_dataset[\"test\"][idx][\"review_title\"]\n",
    "    summary = summarizer(books_dataset[\"test\"][idx][\"review_body\"])[0][\"summary_text\"]\n",
    "    print(f\"'>>> Review: {review}'\")\n",
    "    print(f\"\\n'>>> Title: {title}'\")\n",
    "    print(f\"\\n'>>> Summary: {summary}'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'>>> Review: Nothing special at all about this product... the book is too small and stiff and hard to write in. The huge sticker on the back doesn’t come off and looks super tacky. I would not purchase this again. I could have just bought a journal from the dollar store and it would be basically the same thing. It’s also really expensive for what it is.'\n",
       "\n",
       "'>>> Title: Not impressed at all... buy something else'\n",
       "\n",
       "'>>> Summary: Nothing special at all about this product'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print_summary(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'>>> Review: Es una trilogia que se hace muy facil de leer. Me ha gustado, no me esperaba el final para nada'\n",
       "\n",
       "'>>> Title: Buena literatura para adolescentes'\n",
       "\n",
       "'>>> Summary: Muy facil de leer'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print_summary(0)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Summarization (PyTorch)",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
